{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 解读 `tvm.script.parser.core.evaluator`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Unittests for tvm.script.parser.evaluator\"\"\"\n",
    "import pytest\n",
    "import tvm.testing\n",
    "from tvm.script.parser.core.diagnostics import Source\n",
    "from tvm.script.parser.core.evaluator import ExprEvaluator\n",
    "\n",
    "\n",
    "def _calc(expr, extra_vars=None):\n",
    "    \"\"\"Evaluate the first expression statement found in ``expr``.\n",
    "\n",
    "    The text is wrapped in ``Source``, parsed with ``as_ast()``, and the\n",
    "    ``value`` of the first statement in the module body is handed to\n",
    "    ``ExprEvaluator.eval``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    expr : str\n",
    "        Source text whose first statement is the expression to evaluate.\n",
    "    extra_vars : dict, optional\n",
    "        Name-to-value mapping forwarded to the evaluator; a fresh empty\n",
    "        dict is used when omitted.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The value returned by ``ExprEvaluator.eval``.\n",
    "    \"\"\"\n",
    "    variables = {} if extra_vars is None else extra_vars\n",
    "    # Module AST -> first statement -> the expression that statement wraps.\n",
    "    first_expr = Source(expr).as_ast().body[0].value\n",
    "    # NOTE(review): the first positional argument is None -- presumably no\n",
    "    # parser instance is needed here; confirm against ExprEvaluator.eval.\n",
    "    return ExprEvaluator.eval(None, variables, first_expr)\n",
    "\n",
    "\n",
    "def test_evaluator_basic():\n",
    "    \"\"\"Int, float, bool and str literals come back as a tuple of the same values.\"\"\"\n",
    "    expected = (1, 3.14, True, \"str\")\n",
    "    result = _calc(\"1, 3.14, True, 'str'\")\n",
    "    assert result == expected\n",
    "\n",
    "\n",
    "def test_evaluator_op():\n",
    "    \"\"\"Binary ``+``, ``-``, ``*`` and ``/`` on integer literals are evaluated.\"\"\"\n",
    "    expected = (3, -1, 2, 0.5)\n",
    "    result = _calc(\"1 + 2, 1 - 2, 1 * 2, 1 / 2\")\n",
    "    assert result == expected\n",
    "\n",
    "\n",
    "def test_evaluator_value_table():\n",
    "    \"\"\"Names used in the expression are resolved from the supplied dict.\"\"\"\n",
    "    bindings = {\"a\": 1, \"b\": 2}\n",
    "    a, b = bindings[\"a\"], bindings[\"b\"]\n",
    "    # Compute the reference result with plain Python before evaluating.\n",
    "    expected = (a + b, a - b, a * b, a / b)\n",
    "    assert _calc(\"a + b, a - b, a * b, a / b\", bindings) == expected\n",
    "\n",
    "\n",
    "def test_evaluator_func_call():\n",
    "    \"\"\"A callable passed in the variable dict can be called from the expression.\"\"\"\n",
    "\n",
    "    def arith(a, b):\n",
    "        return a + b, a - b, a * b, a / b\n",
    "\n",
    "    expected = arith(1, 2)\n",
    "    # The expression refers to the callable by its dict key, \"func\".\n",
    "    result = _calc(\"func(1, 2)\", {\"func\": arith})\n",
    "    assert result == expected\n",
    "\n",
    "\n",
    "def test_evaluator_slice():\n",
    "    \"\"\"Plain, open-ended, bounded and stepped subscripts of a list are evaluated.\"\"\"\n",
    "    seq = [1, 2, 3, 4, 5, 6]\n",
    "    expected = (seq, seq[1:], seq[:5], seq[1:5], seq[1:5:2])\n",
    "    # Hand the evaluator its own copy of the list, bound to the name \"a\".\n",
    "    result = _calc(\"a, a[1:], a[:5], a[1: 5], a[1: 5: 2]\", {\"a\": list(seq)})\n",
    "    assert result == expected\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Inside a notebook kernel __name__ is \"__main__\", so this branch runs.\n",
    "    # tvm.testing.main() re-launches pytest on the calling source file with\n",
    "    # sys.argv[1:] and then calls sys.exit(); in a kernel there is no real\n",
    "    # test file and argv holds kernel flags, so call the tests directly.\n",
    "    for _test in (\n",
    "        test_evaluator_basic,\n",
    "        test_evaluator_op,\n",
    "        test_evaluator_value_table,\n",
    "        test_evaluator_func_call,\n",
    "        test_evaluator_slice,\n",
    "    ):\n",
    "        _test()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
